{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 练习三 动态规划进行策略评估、策略迭代和价值迭代"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1 小型方格世界MDP建模"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "# 4*4 方格状态命名\n",
    "# 状态0和15为终止状态\n",
    "#  0  1  2  3\n",
    "#  4  5  6  7\n",
    "#  8  9 10  11\n",
    "# 12 13 14  15"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "S = [i for i in range(16)] # 状态空间\n",
    "A = [\"n\", \"e\", \"s\", \"w\"] # 行为空间\n",
    "# P,R,将由dynamics动态生成\n",
    "\n",
    "ds_actions = {\"n\": -4, \"e\": 1, \"s\": 4, \"w\": -1} # 行为对状态的改变\n",
    "\n",
    "def dynamics(s, a): # 环境动力学\n",
    "    '''模拟小型方格世界的环境动力学特征\n",
    "    Args:\n",
    "        s 当前状态 int 0 - 15\n",
    "        a 行为 str in ['n','e','s','w'] 分别表示北、东、南、西\n",
    "    Returns: tuple (s_prime, reward, is_end)\n",
    "        s_prime 后续状态\n",
    "        reward 奖励值\n",
    "        is_end 是否进入终止状态\n",
    "    '''\n",
    "    s_prime = s\n",
    "    if (s%4 == 0 and a == \"w\") or (s<4 and a == \"n\") \\\n",
    "        or ((s+1)%4 == 0 and a == \"e\") or (s > 11 and a == \"s\")\\\n",
    "        or s in [0, 15]:\n",
    "        pass\n",
    "    else:\n",
    "        ds = ds_actions[a]\n",
    "        s_prime = s + ds\n",
    "    reward = 0 if s in [0, 15] else -1\n",
    "    is_end = True if s in [0, 15] else False\n",
    "    return s_prime, reward, is_end\n",
    "\n",
    "def P(s, a, s1): # 状态转移概率函数\n",
    "    s_prime, _, _ = dynamics(s, a)\n",
    "    return s1 == s_prime\n",
    "\n",
    "def R(s, a): # 奖励函数\n",
    "    _, r, _ = dynamics(s, a)\n",
    "    return r\n",
    "\n",
    "gamma = 1.00\n",
    "MDP = S, A, R, P, gamma"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2 策略"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def uniform_random_pi(MDP = None, V = None, s = None, a = None):\n",
    "    _, A, _, _, _ = MDP\n",
    "    n = len(A)\n",
    "    return 0 if n == 0 else 1.0/n\n",
    "\n",
    "def greedy_pi(MDP, V, s, a):\n",
    "    S, A, P, R, gamma = MDP\n",
    "    max_v, a_max_v = -float('inf'), []\n",
    "    for a_opt in A:# 统计后续状态的最大价值以及到达到达该状态的行为（可能不止一个）\n",
    "        s_prime, reward, _ = dynamics(s, a_opt)\n",
    "        v_s_prime = get_value(V, s_prime)\n",
    "        if v_s_prime > max_v:\n",
    "            max_v = v_s_prime\n",
    "            a_max_v = [a_opt]\n",
    "        elif(v_s_prime == max_v):\n",
    "            a_max_v.append(a_opt)\n",
    "    n = len(a_max_v)\n",
    "    if n == 0: return 0.0\n",
    "    return 1.0/n if a in a_max_v else 0.0\n",
    "\n",
    "def get_pi(Pi, s, a, MDP = None, V = None):\n",
    "    return Pi(MDP, V, s, a)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 辅助函数\n",
    "def get_prob(P, s, a, s1): # 获取状态转移概率\n",
    "    return P(s, a, s1)\n",
    "\n",
    "def get_reward(R, s, a): # 获取奖励值\n",
    "    return R(s, a)\n",
    "\n",
    "def set_value(V, s, v): # 设置价值字典\n",
    "    V[s] = v\n",
    "    \n",
    "def get_value(V, s): # 获取状态价值\n",
    "    return V[s]\n",
    "\n",
    "def display_V(V): # 显示状态价值\n",
    "    for i in range(16):\n",
    "        print('{0:>6.2f}'.format(V[i]),end = \" \")\n",
    "        if (i+1) % 4 == 0:\n",
    "            print(\"\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. 策略评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compute_q(MDP, V, s, a):\n",
    "    '''根据给定的MDP，价值函数V，计算状态行为对s,a的价值qsa\n",
    "    '''\n",
    "    S, A, R, P, gamma = MDP\n",
    "    q_sa = 0\n",
    "    for s_prime in S:\n",
    "        q_sa += get_prob(P, s, a, s_prime) * get_value(V, s_prime)\n",
    "    q_sa = get_reward(R, s,a) + gamma * q_sa\n",
    "    return q_sa\n",
    "\n",
    "def compute_v(MDP, V, Pi, s):\n",
    "    '''给定MDP下依据某一策略Pi和当前状态价值函数V计算某状态s的价值\n",
    "    '''\n",
    "    S, A, R, P, gamma = MDP\n",
    "    v_s = 0\n",
    "    for a in A:\n",
    "        v_s += get_pi(Pi, s, a, MDP, V) * compute_q(MDP, V, s, a)\n",
    "    return v_s        \n",
    "\n",
    "def update_V(MDP, V, Pi):\n",
    "    '''给定一个MDP和一个策略，更新该策略下的价值函数V\n",
    "    '''\n",
    "    S, _, _, _, _ = MDP\n",
    "    V_prime = V.copy()\n",
    "    for s in S:\n",
    "        set_value(V_prime, s, compute_v(MDP, V_prime, Pi, s))\n",
    "    return V_prime\n",
    "\n",
    "\n",
    "def policy_evaluate(MDP, V, Pi, n):\n",
    "    '''使用n次迭代计算来评估一个MDP在给定策略Pi下的状态价值，初始时价值为V\n",
    "    '''\n",
    "    for i in range(n):\n",
    "        #print(\"====第{}次迭代====\".format(i+1))\n",
    "        V = update_V(MDP, V, Pi)\n",
    "        #display_V(V)\n",
    "    return V\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0.00 -14.00 -20.00 -22.00 \n",
      "-14.00 -18.00 -20.00 -20.00 \n",
      "-20.00 -20.00 -18.00 -14.00 \n",
      "-22.00 -20.00 -14.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -2.00  -3.00 \n",
      " -1.00  -2.00  -3.00  -2.00 \n",
      " -2.00  -3.00  -2.00  -1.00 \n",
      " -3.00  -2.00  -1.00   0.00 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "V = [0  for _ in range(16)] # 状态价值\n",
    "V_pi = policy_evaluate(MDP, V, uniform_random_pi, 100)\n",
    "display_V(V_pi)\n",
    "\n",
    "V = [0  for _ in range(16)] # 状态价值\n",
    "V_pi = policy_evaluate(MDP, V, greedy_pi, 100)\n",
    "display_V(V_pi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4 策略迭代"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def policy_iterate(MDP, V, Pi, n, m):\n",
    "    for i in range(m):\n",
    "        V = policy_evaluate(MDP, V, Pi, n)\n",
    "        Pi = greedy_pi # 第一次迭代产生新的价值函数后随机使用贪婪策略\n",
    "    return V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0.00  -1.00  -2.00  -3.00 \n",
      " -1.00  -2.00  -3.00  -2.00 \n",
      " -2.00  -3.00  -2.00  -1.00 \n",
      " -3.00  -2.00  -1.00   0.00 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "V = [0  for _ in range(16)] # 重置状态价值\n",
    "V_pi = policy_iterate(MDP, V, greedy_pi, 1, 100)\n",
    "display_V(V_pi)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 价值迭代"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 价值迭代得到最优状态价值过程\n",
    "def compute_v_from_max_q(MDP, V, s):\n",
    "    '''根据一个状态的下所有可能的行为价值中最大一个来确定当前状态价值\n",
    "    '''\n",
    "    S, A, R, P, gamma = MDP\n",
    "    v_s = -float('inf')\n",
    "    for a in A:\n",
    "        qsa = compute_q(MDP, V, s, a)\n",
    "        if qsa >= v_s:\n",
    "            v_s = qsa\n",
    "    return v_s\n",
    "\n",
    "def update_V_without_pi(MDP, V):\n",
    "    '''在不依赖策略的情况下直接通过后续状态的价值来更新状态价值\n",
    "    '''\n",
    "    S, _, _, _, _ = MDP\n",
    "    V_prime = V.copy()\n",
    "    for s in S:\n",
    "        set_value(V_prime, s, compute_v_from_max_q(MDP, V_prime, s))\n",
    "    return V_prime\n",
    "\n",
    "def value_iterate(MDP, V, n):\n",
    "    '''价值迭代\n",
    "    '''\n",
    "    for i in range(n):\n",
    "        V = update_V_without_pi(MDP, V)\n",
    "        display_V(V)\n",
    "    return V"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  0.00   0.00   0.00   0.00 \n",
      "  0.00   0.00   0.00   0.00 \n",
      "  0.00   0.00   0.00   0.00 \n",
      "  0.00   0.00   0.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -1.00  -1.00 \n",
      " -1.00  -1.00  -1.00  -1.00 \n",
      " -1.00  -1.00  -1.00  -1.00 \n",
      " -1.00  -1.00  -1.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -2.00  -2.00 \n",
      " -1.00  -2.00  -2.00  -2.00 \n",
      " -2.00  -2.00  -2.00  -1.00 \n",
      " -2.00  -2.00  -1.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -2.00  -3.00 \n",
      " -1.00  -2.00  -3.00  -2.00 \n",
      " -2.00  -3.00  -2.00  -1.00 \n",
      " -3.00  -2.00  -1.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -2.00  -3.00 \n",
      " -1.00  -2.00  -3.00  -2.00 \n",
      " -2.00  -3.00  -2.00  -1.00 \n",
      " -3.00  -2.00  -1.00   0.00 \n",
      "\n",
      "  0.00  -1.00  -2.00  -3.00 \n",
      " -1.00  -2.00  -3.00  -2.00 \n",
      " -2.00  -3.00  -2.00  -1.00 \n",
      " -3.00  -2.00  -1.00   0.00 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "V = [0  for _ in range(16)] # 重置状态价值\n",
    "display_V(V)\n",
    "\n",
    "V_star = value_iterate(MDP, V, 4)\n",
    "display_V(V_star)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def greedy_policy(MDP, V, s):\n",
    "    S, A, P, R, gamma = MDP\n",
    "    max_v, a_max_v = -float('inf'), []\n",
    "    for a_opt in A:# 统计后续状态的最大价值以及到达到达该状态的行为（可能不止一个）\n",
    "        s_prime, reward, _ = dynamics(s, a_opt)\n",
    "        v_s_prime = get_value(V, s_prime)\n",
    "        if v_s_prime > max_v:\n",
    "            max_v = v_s_prime\n",
    "            a_max_v = a_opt\n",
    "        elif(v_s_prime == max_v):\n",
    "            a_max_v += a_opt\n",
    "    return str(a_max_v)\n",
    "\n",
    "def display_policy(policy, MDP, V):\n",
    "    S, A, P, R, gamma = MDP\n",
    "    for i in range(16):\n",
    "        print('{0:^6}'.format(policy(MDP, V, S[i])),end = \" \")\n",
    "        if (i+1) % 4 == 0:\n",
    "            print(\"\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " nesw    w      w      sw   \n",
      "  n      nw    nesw    s    \n",
      "  n     nesw    es     s    \n",
      "  ne     e      e     nesw  \n",
      "\n"
     ]
    }
   ],
   "source": [
    "display_policy(greedy_policy, MDP, V_star)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
